<!DOCTYPE html>
<html lang="zh-cn" dir="ltr">
    <head><meta charset='utf-8'>
<meta name='viewport' content='width=device-width, initial-scale=1'><meta name='description' content='孪生网络之小样本学习： DL标准分类方法：输入x通过多层神经网络后，输出x属于某一类（或多个类）的概率。 小样本分类方法：每个类只需要为数不多的训练样本，要求要识别的样本与训练样本相似，如人脸识别。 孪生网络 孪生网络和学习输入的特征进行分类的网络不同，孪生网络在两个输入中进行区分，即是学习两个输入的相似度。
孪生网络由两个完全相同的神经网络组成，每个都采用两个输入图像中的一个。然后将两个网络的最后一层馈送到对比损失函数，用来计算两个图像之间的相似度。它具有两个姐妹网络，它们是具有完全相同权重的相同神经网络。图像对中的每个图像将被馈送到这些网络中的一个。使用对比损失函数优化网络（我们将获得确切的函数）
损失函数： 使用和普通神经网络所不同的损失函数：对比损失函数（Contrastive Loss Function） 孪生架构的目的不是对输入特征进行分类，而是区分它们。因此，分类损失函数（如交叉熵）不是最合适的选择。相反，这种架构更适合使用对比函数。一般而言，这个函数只是评估网络区分一对给定的图像的效果如何。 $$ L = \frac{1} {2N} \sum_{i=1}^{n} yd^2 &#43; (1-y)max(margin-d, 0)^2 $$
其中$d=||an−bn||_2$，代表两个样本特征的欧氏距离，y为两个样本是否匹配的标签，y=1代表两个样本相似或者匹配，y=0则代表不匹配，margin为设定的阈值。 欧氏距离d： $$ d=\sqrt{(G_w(x_1) - G_w(x_2))^2} $$ 其中$G_w$是其中一个姐妹网络的输出。$X1$和$X2$是输入数据对。 说明： Y值为1或0。如果模型预测输入是相似的，那么Y的值为0，否则Y为1。 max（）是表示0和m-Dw之间较大值的函数。 m是大于0的边际价值（margin value）。有一个边际价值表示超出该边际价值的不同对不会造成损失。这是有道理的，因为你只希望基于实际不相似对来优化网络，但网络认为是相当相似的。
处理多分类、少样本的问题： 训练数据要保证基本正、负样本比例接近1:1 数据集划分：训练集，测试集，支持集（Support Set），其中支持集包含所有类的样本
训练： x1、x2可以是正样本、负样本
样本x1通过网络得到输出y1 样本x2通过网络得到输出y2 使用y1、y2计算对比损失 反向传播计算梯度 使用优化器更新权重 训练集中，x1、x2和标签y的对应关系 x1和x2的关系 y x1、x2属于同一类 0 x1、x2属于不同类 1 测试： 给定测试样本 x ，从支持集中依次取出每个类 x_i（ i=1,2,3,&amp;hellip;.n ），x 和所有 x_i 依次通过孪生网络，若 x_j 和 x 的相似度最高，则认为 x 属于第 j 类（即 x_j 所属的类）'><title>孪生网络实现多分类（Siamese Network）</title>

<link rel='canonical' href='https://charent.github.io/p/%E5%AD%AA%E7%94%9F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BBsiamese-network/'>

<link rel="stylesheet" href="/scss/style.min.65463afd28d606277b44441c8bbb8b0277823a2d0c03ab8ba9d0567d274f7b43.css"><meta property='og:title' content='孪生网络实现多分类（Siamese Network）'>
<meta property='og:description' content='孪生网络之小样本学习： DL标准分类方法：输入x通过多层神经网络后，输出x属于某一类（或多个类）的概率。 小样本分类方法：每个类只需要为数不多的训练样本，要求要识别的样本与训练样本相似，如人脸识别。 孪生网络 孪生网络和学习输入的特征进行分类的网络不同，孪生网络在两个输入中进行区分，即是学习两个输入的相似度。
孪生网络由两个完全相同的神经网络组成，每个都采用两个输入图像中的一个。然后将两个网络的最后一层馈送到对比损失函数，用来计算两个图像之间的相似度。它具有两个姐妹网络，它们是具有完全相同权重的相同神经网络。图像对中的每个图像将被馈送到这些网络中的一个。使用对比损失函数优化网络（我们将获得确切的函数）
损失函数： 使用和普通神经网络所不同的损失函数：对比损失函数（Contrastive Loss Function） 孪生架构的目的不是对输入特征进行分类，而是区分它们。因此，分类损失函数（如交叉熵）不是最合适的选择。相反，这种架构更适合使用对比函数。一般而言，这个函数只是评估网络区分一对给定的图像的效果如何。 $$ L = \frac{1} {2N} \sum_{i=1}^{n} yd^2 &#43; (1-y)max(margin-d, 0)^2 $$
其中$d=||an−bn||_2$，代表两个样本特征的欧氏距离，y为两个样本是否匹配的标签，y=1代表两个样本相似或者匹配，y=0则代表不匹配，margin为设定的阈值。 欧氏距离d： $$ d=\sqrt{(G_w(x_1) - G_w(x_2))^2} $$ 其中$G_w$是其中一个姐妹网络的输出。$X1$和$X2$是输入数据对。 说明： Y值为1或0。如果模型预测输入是相似的，那么Y的值为0，否则Y为1。 max（）是表示0和m-Dw之间较大值的函数。 m是大于0的边际价值（margin value）。有一个边际价值表示超出该边际价值的不同对不会造成损失。这是有道理的，因为你只希望基于实际不相似对来优化网络，但网络认为是相当相似的。
处理多分类、少样本的问题： 训练数据要保证基本正、负样本比例接近1:1 数据集划分：训练集，测试集，支持集（Support Set），其中支持集包含所有类的样本
训练： x1、x2可以是正样本、负样本
样本x1通过网络得到输出y1 样本x2通过网络得到输出y2 使用y1、y2计算对比损失 反向传播计算梯度 使用优化器更新权重 训练集中，x1、x2和标签y的对应关系 x1和x2的关系 y x1、x2属于同一类 0 x1、x2属于不同类 1 测试： 给定测试样本 x ，从支持集中依次取出每个类 x_i（ i=1,2,3,&amp;hellip;.n ），x 和所有 x_i 依次通过孪生网络，若 x_j 和 x 的相似度最高，则认为 x 属于第 j 类（即 x_j 所属的类）'>
<meta property='og:url' content='https://charent.github.io/p/%E5%AD%AA%E7%94%9F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BBsiamese-network/'>
<meta property='og:site_name' content='Charent的博客'>
<meta property='og:type' content='article'><meta property='article:section' content='Post' /><meta property='article:published_time' content='2020-03-05T00:00:00&#43;00:00'/><meta property='article:modified_time' content='2020-03-05T00:00:00&#43;00:00'/>
<meta name="twitter:title" content="孪生网络实现多分类（Siamese Network）">
<meta name="twitter:description" content="孪生网络之小样本学习： DL标准分类方法：输入x通过多层神经网络后，输出x属于某一类（或多个类）的概率。 小样本分类方法：每个类只需要为数不多的训练样本，要求要识别的样本与训练样本相似，如人脸识别。 孪生网络 孪生网络和学习输入的特征进行分类的网络不同，孪生网络在两个输入中进行区分，即是学习两个输入的相似度。
孪生网络由两个完全相同的神经网络组成，每个都采用两个输入图像中的一个。然后将两个网络的最后一层馈送到对比损失函数，用来计算两个图像之间的相似度。它具有两个姐妹网络，它们是具有完全相同权重的相同神经网络。图像对中的每个图像将被馈送到这些网络中的一个。使用对比损失函数优化网络（我们将获得确切的函数）
损失函数： 使用和普通神经网络所不同的损失函数：对比损失函数（Contrastive Loss Function） 孪生架构的目的不是对输入特征进行分类，而是区分它们。因此，分类损失函数（如交叉熵）不是最合适的选择。相反，这种架构更适合使用对比函数。一般而言，这个函数只是评估网络区分一对给定的图像的效果如何。 $$ L = \frac{1} {2N} \sum_{i=1}^{n} yd^2 &#43; (1-y)max(margin-d, 0)^2 $$
其中$d=||an−bn||_2$，代表两个样本特征的欧氏距离，y为两个样本是否匹配的标签，y=1代表两个样本相似或者匹配，y=0则代表不匹配，margin为设定的阈值。 欧氏距离d： $$ d=\sqrt{(G_w(x_1) - G_w(x_2))^2} $$ 其中$G_w$是其中一个姐妹网络的输出。$X1$和$X2$是输入数据对。 说明： Y值为1或0。如果模型预测输入是相似的，那么Y的值为0，否则Y为1。 max（）是表示0和m-Dw之间较大值的函数。 m是大于0的边际价值（margin value）。有一个边际价值表示超出该边际价值的不同对不会造成损失。这是有道理的，因为你只希望基于实际不相似对来优化网络，但网络认为是相当相似的。
处理多分类、少样本的问题： 训练数据要保证基本正、负样本比例接近1:1 数据集划分：训练集，测试集，支持集（Support Set），其中支持集包含所有类的样本
训练： x1、x2可以是正样本、负样本
样本x1通过网络得到输出y1 样本x2通过网络得到输出y2 使用y1、y2计算对比损失 反向传播计算梯度 使用优化器更新权重 训练集中，x1、x2和标签y的对应关系 x1和x2的关系 y x1、x2属于同一类 0 x1、x2属于不同类 1 测试： 给定测试样本 x ，从支持集中依次取出每个类 x_i（ i=1,2,3,&amp;hellip;.n ），x 和所有 x_i 依次通过孪生网络，若 x_j 和 x 的相似度最高，则认为 x 属于第 j 类（即 x_j 所属的类）">
    </head>
    <body class="
    article-page
    ">
    <script>
        (function() {
            const colorSchemeKey = 'StackColorScheme';
            if(!localStorage.getItem(colorSchemeKey)){
                localStorage.setItem(colorSchemeKey, "auto");
            }
        })();
    </script><script>
    (function() {
        const colorSchemeKey = 'StackColorScheme';
        const colorSchemeItem = localStorage.getItem(colorSchemeKey);
        const supportDarkMode = window.matchMedia('(prefers-color-scheme: dark)').matches === true;

        if (colorSchemeItem == 'dark' || colorSchemeItem === 'auto' && supportDarkMode) {
            

            document.documentElement.dataset.scheme = 'dark';
        } else {
            document.documentElement.dataset.scheme = 'light';
        }
    })();
</script>
<div class="container main-container flex on-phone--column extended"><aside class="sidebar left-sidebar sticky ">
    <button class="hamburger hamburger--spin" type="button" id="toggle-menu" aria-label="切换菜单">
        <span class="hamburger-box">
            <span class="hamburger-inner"></span>
        </span>
    </button>

    <header>
        
            
            <figure class="site-avatar">
                <a href="/">
                
                    
                    
                    
                        
                        <img src="/img/avatar_hu544a9780599d84daff7ca906f17dd84c_13353_300x0_resize_box_3.png" width="300"
                            height="300" class="site-logo" loading="lazy" alt="Avatar">
                    
                
                </a>
                
                    <span class="emoji">😆</span>
                
            </figure>
            
        
        
        <div class="site-meta">
            <h1 class="site-name"><a href="/">Charent的博客</a></h1>
            <h2 class="site-description">不积硅步，无以至千里</h2>
        </div>
    </header><ol class="social-menu">
            
                <li>
                    <a 
                        href='https://github.com/charent'
                        target="_blank"
                        title="GitHub"
                    >
                        
                        
                            <svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-brand-github" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
  <path stroke="none" d="M0 0h24v24H0z" fill="none"/>
  <path d="M9 19c-4.3 1.4 -4.3 -2.5 -6 -3m12 5v-3.5c0 -1 .1 -1.4 -.5 -2c2.8 -.3 5.5 -1.4 5.5 -6a4.6 4.6 0 0 0 -1.3 -3.2a4.2 4.2 0 0 0 -.1 -3.2s-1.1 -.3 -3.5 1.3a12.3 12.3 0 0 0 -6.2 0c-2.4 -1.6 -3.5 -1.3 -3.5 -1.3a4.2 4.2 0 0 0 -.1 3.2a4.6 4.6 0 0 0 -1.3 3.2c0 4.6 2.7 5.7 5.5 6c-.6 .6 -.6 1.2 -.5 2v3.5" />
</svg>



                        
                    </a>
                </li>
            
        </ol><ol class="menu" id="main-menu">
        
        
        

        <li >
            <a href='/' >
                
                
                
                    <svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-home" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
  <path stroke="none" d="M0 0h24v24H0z"/>
  <polyline points="5 12 3 12 12 3 21 12 19 12" />
  <path d="M5 12v7a2 2 0 0 0 2 2h10a2 2 0 0 0 2 -2v-7" />
  <path d="M9 21v-6a2 2 0 0 1 2 -2h2a2 2 0 0 1 2 2v6" />
</svg>



                
                <span>主页</span>
            </a>
        </li>
        
        

        <li >
            <a href='/archives/' >
                
                
                
                    <svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-archive" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
  <path stroke="none" d="M0 0h24v24H0z"/>
  <rect x="3" y="4" width="18" height="4" rx="2" />
  <path d="M5 8v10a2 2 0 0 0 2 2h10a2 2 0 0 0 2 -2v-10" />
  <line x1="10" y1="12" x2="14" y2="12" />
</svg>



                
                <span>归档</span>
            </a>
        </li>
        
        

        <li >
            <a href='/search/' >
                
                
                
                    <svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-search" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
  <path stroke="none" d="M0 0h24v24H0z"/>
  <circle cx="10" cy="10" r="7" />
  <line x1="21" y1="21" x2="15" y2="15" />
</svg>



                
                <span>搜索</span>
            </a>
        </li>
        
        

        <li >
            <a href='/%E5%85%B3%E4%BA%8E/' >
                
                
                
                    <svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-link" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
  <path stroke="none" d="M0 0h24v24H0z"/>
  <path d="M10 14a3.5 3.5 0 0 0 5 0l4 -4a3.5 3.5 0 0 0 -5 -5l-.5 .5" />
  <path d="M14 10a3.5 3.5 0 0 0 -5 0l-4 4a3.5 3.5 0 0 0 5 5l.5 -.5" />
</svg>



                
                <span>关于</span>
            </a>
        </li>
        

        <div class="menu-bottom-section">
                <li id="i18n-switch">  
                    <svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-language" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
  <path stroke="none" d="M0 0h24v24H0z" fill="none"/>
  <path d="M4 5h7" />
  <path d="M9 3v2c0 4.418 -2.239 8 -5 8" />
  <path d="M5 9c-.003 2.144 2.952 3.908 6.7 4" />
  <path d="M12 20l4 -9l4 9" />
  <path d="M19.1 18h-6.2" />
</svg>



                    <select name="language" onchange="window.location.href = this.selectedOptions[0].value">
                        
                            <option value="https://charent.github.io/" selected>中文</option>
                        
                            <option value="https://charent.github.io/en/" >English</option>
                        
                    </select>
                </li>
            
            
            
                <li id="dark-mode-toggle">
                    <svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-toggle-left" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
  <path stroke="none" d="M0 0h24v24H0z"/>
  <circle cx="8" cy="12" r="2" />
  <rect x="2" y="6" width="20" height="12" rx="6" />
</svg>



                    <svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-toggle-right" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
  <path stroke="none" d="M0 0h24v24H0z"/>
  <circle cx="16" cy="12" r="2" />
  <rect x="2" y="6" width="20" height="12" rx="6" />
</svg>



                    <span>暗色模式</span>
                </li>
            
        </div>
    </ol>
</aside>
<main class="main full-width">
    <article class="main-article">
    <header class="article-header">

    <div class="article-details">
    
    <header class="article-category">
        
            <a href="/categories/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/" >
                深度学习
            </a>
        
            <a href="/categories/%E5%88%86%E7%B1%BB%E7%AE%97%E6%B3%95/" >
                分类算法
            </a>
        
            <a href="/categories/tensorflow/" >
                Tensorflow
            </a>
        
    </header>
    

    <div class="article-title-wrapper">
        <h2 class="article-title">
            <a href="/p/%E5%AD%AA%E7%94%9F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BBsiamese-network/">孪生网络实现多分类（Siamese Network）</a>
        </h2>
    
        
    </div>

    
    <footer class="article-time">
        
            <div>
                <svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-calendar-time" width="56" height="56" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
  <path stroke="none" d="M0 0h24v24H0z"/>
  <path d="M11.795 21h-6.795a2 2 0 0 1 -2 -2v-12a2 2 0 0 1 2 -2h12a2 2 0 0 1 2 2v4" />
  <circle cx="18" cy="18" r="4" />
  <path d="M15 3v4" />
  <path d="M7 3v4" />
  <path d="M3 11h16" />
  <path d="M18 16.496v1.504l1 1" />
</svg>
                <time class="article-time--published">2020-03-05</time>
            </div>
        

        
    </footer>
    

    
</div>
</header>

    <section class="article-content">
    
    
    <h4 id="孪生网络之小样本学习">孪生网络之小样本学习：</h4>
<ol>
<li>DL标准分类方法：输入x通过多层神经网络后，输出x属于某一类（或多个类）的概率。</li>
<li>小样本分类方法：每个类只需要为数不多的训练样本，要求要识别的样本与训练样本相似，如人脸识别。</li>
</ol>
<h4 id="孪生网络">孪生网络</h4>
<p>孪生网络和学习输入的特征进行分类的网络不同，孪生网络在两个输入中进行区分，即是学习两个输入的相似度。</p>
<p>孪生网络由两个完全相同的神经网络组成，每个都采用两个输入图像中的一个。然后将两个网络的最后一层馈送到对比损失函数，用来计算两个图像之间的相似度。它具有两个姐妹网络，它们是具有完全相同权重的相同神经网络。图像对中的每个图像将被馈送到这些网络中的一个。使用对比损失函数优化网络（我们将获得确切的函数）</p>
<p><img src="/p/%E5%AD%AA%E7%94%9F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BBsiamese-network/1.png"
	width="514"
	height="660"
	srcset="/p/%E5%AD%AA%E7%94%9F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BBsiamese-network/1_hu08795b9674ed7ebcd18faa6b8eccb258_71605_480x0_resize_box_3.png 480w, /p/%E5%AD%AA%E7%94%9F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BBsiamese-network/1_hu08795b9674ed7ebcd18faa6b8eccb258_71605_1024x0_resize_box_3.png 1024w"
	loading="lazy"
	
		alt="孪生网络"
	
	
		class="gallery-image" 
		data-flex-grow="77"
		data-flex-basis="186px"
	
></p>
<h4 id="损失函数">损失函数：</h4>
<p>使用和普通神经网络所不同的损失函数：对比损失函数（Contrastive Loss Function）
孪生架构的目的不是对输入特征进行分类，而是区分它们。因此，分类损失函数（如交叉熵）不是最合适的选择。相反，这种架构更适合使用对比函数。一般而言，这个函数只是评估网络区分一对给定的图像的效果如何。
$$
L = \frac{1} {2N} \sum_{i=1}^{n} yd^2 + (1-y)max(margin-d, 0)^2
$$</p>
<p>其中$d=||an−bn||_2$，代表两个样本特征的欧氏距离，y为两个样本是否匹配的标签，<code>y=1</code>代表两个样本相似或者匹配，<code>y=0</code>则代表不匹配，<code>margin</code>为设定的阈值。
欧氏距离<code>d</code>：
$$
d=\sqrt{(G_w(x_1) - G_w(x_2))^2}
$$
其中$G_w$是其中一个姐妹网络的输出。$X1$和$X2$是输入数据对。
说明：
Y值为1或0。如果模型预测输入是相似的，那么Y的值为0，否则Y为1。
max（）是表示0和m-Dw之间较大值的函数。
m是大于0的边际价值（margin value）。有一个边际价值表示超出该边际价值的不同对不会造成损失。这是有道理的，因为你只希望基于实际不相似对来优化网络，但网络认为是相当相似的。</p>
<h4 id="处理多分类少样本的问题">处理多分类、少样本的问题：</h4>
<p>训练数据要保证基本正、负样本比例接近1:1
数据集划分：训练集，测试集，支持集（Support Set），其中支持集包含所有类的样本</p>
<h4 id="训练">训练：</h4>
<p>x1、x2可以是正样本、负样本</p>
<ol>
<li>样本x1通过网络得到输出y1</li>
<li>样本x2通过网络得到输出y2</li>
<li>使用y1、y2计算对比损失</li>
<li>反向传播计算梯度</li>
<li>使用优化器更新权重</li>
<li>训练集中，x1、x2和标签y的对应关系</li>
</ol>
<div class="table-wrapper"><table>
<thead>
<tr>
<th>x1和x2的关系</th>
<th style="text-align:center">y</th>
</tr>
</thead>
<tbody>
<tr>
<td>x1、x2属于同一类</td>
<td style="text-align:center">0</td>
</tr>
<tr>
<td>x1、x2属于不同类</td>
<td style="text-align:center">1</td>
</tr>
</tbody>
</table></div>
<h4 id="测试">测试：</h4>
<p>给定测试样本 x ，从支持集中依次取出每个类 x_i（ i=1,2,3,&hellip;.n ），x 和所有 x_i 依次通过孪生网络，若 x_j 和 x 的相似度最高，则认为 x 属于第 j 类（即 x_j 所属的类）</p>
<h4 id="实验">实验：</h4>
<h5 id="数据">数据：</h5>
<p>新闻标题数据集，14个类。小样本：每个类中随机抽取50条数据做为训练数据。</p>
<p>训练集：每个类的50条样本和自身做笛卡尔乘积，得到 50 x 50 = 2500 条正样本，即这2500条样本（x1,x2）对应的y=1；负样本：对每一类，从其余类的各50个样本中，每个类随机抽取 50 / 14 = 4个样本，共56个样本，和该类的50个样本做笛卡尔积组成 50 x 56 = 2800个负样本，即是(x1, x2)对应的 y=0，x1为该类样本，x2为其他类样本。共71400条训练数据
测试集：每个类随机选取3000条，共42000条数据。</p>
<p>支持集：支持集的选定较为困难。孪生网络在人脸识别中取得的效果非常好，网络训练好之后，直接拿一张人脸照片就可以作为支持集，支持集的样本和测试集的样本输入孪生网络后，网络会输出这两个样本的相似度，再根据相似度判断测试样本和支持样本是否属于同一个类。具体支持集的选取会在实验部分讨论。</p>
<h4 id="字向量">字向量：</h4>
<p>使用开源的基于人民日报语料、Word + Ngram训练的预训练词向量，包含1664K的字、词。每个字、词的维度是300。向量矩阵大小：972M</p>
<h4 id="模型">模型：</h4>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span><span class="lnt">4
</span><span class="lnt">5
</span><span class="lnt">6
</span><span class="lnt">7
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="n">SiameseClassifier</span><span class="p">(</span>
</span></span><span class="line"><span class="cl">  <span class="n">Conv1D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">  <span class="n">Conv1D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">  <span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">  <span class="n">Dense</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">  <span class="n">Dense</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;sigmoid&#39;</span><span class="p">)</span>
</span></span><span class="line"><span class="cl"><span class="p">)</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>模型最后输出一个长度为16的一维数组，用于计算相似度。</p>
<h4 id="损失函数-1">损失函数：</h4>
<p>使用对比损失函数，描述两个输出向量之间的欧氏距离的损失。</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span><span class="lnt">25
</span><span class="lnt">26
</span><span class="lnt">27
</span><span class="lnt">28
</span><span class="lnt">29
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="c1">#对比损失函数</span>
</span></span><span class="line"><span class="cl"><span class="k">class</span> <span class="nc">ContrastiveLoss</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">Loss</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">margin</span><span class="o">=</span><span class="mf">2.0</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="s1">&#39;&#39;&#39;
</span></span></span><span class="line"><span class="cl"><span class="s1">        margin为边界值
</span></span></span><span class="line"><span class="cl"><span class="s1">        &#39;&#39;&#39;</span>
</span></span><span class="line"><span class="cl">        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
</span></span><span class="line"><span class="cl">        <span class="bp">self</span><span class="o">.</span><span class="n">margin</span> <span class="o">=</span> <span class="n">margin</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">    <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="s1">&#39;&#39;&#39;
</span></span></span><span class="line"><span class="cl"><span class="s1">        output1和output2为二维数组:(batch_size, dim)
</span></span></span><span class="line"><span class="cl"><span class="s1">        &#39;&#39;&#39;</span>
</span></span><span class="line"><span class="cl">        <span class="n">output1</span> <span class="o">=</span> <span class="n">y_pred</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">output2</span> <span class="o">=</span> <span class="n">y_pred</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">        <span class="n">label</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span><span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="c1">#计算欧氏距离 d = sqrt(sum(pow((x - y), 2))</span>
</span></span><span class="line"><span class="cl">        <span class="n">d_eu</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">output1</span> <span class="o">-</span> <span class="n">output2</span><span class="p">),</span> 
</span></span><span class="line"><span class="cl">                        <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span><span class="n">keepdims</span><span class="o">=</span><span class="kc">False</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">        
</span></span><span class="line"><span class="cl">        <span class="c1">#计算对比损失：</span>
</span></span><span class="line"><span class="cl">        <span class="c1">#loss= Y * (Dw)^2 / 2 + (1 - Y) * max((0, margin - Dw))^2 / 2</span>
</span></span><span class="line"><span class="cl">        <span class="c1">#其中，Dw为模型输出之间的欧氏距离，Y为标签，0或1，margin为边界值</span>
</span></span><span class="line"><span class="cl">        <span class="n">loss</span> <span class="o">=</span> <span class="p">(</span><span class="n">label</span> <span class="o">*</span> <span class="n">tf</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">d_eu</span><span class="p">)</span> <span class="o">+</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">label</span><span class="p">)</span> <span class="o">*</span> 
</span></span><span class="line"><span class="cl">                <span class="n">tf</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">margin</span> <span class="o">-</span> <span class="n">d_eu</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">)))</span> <span class="o">/</span> <span class="mi">2</span>
</span></span><span class="line"><span class="cl">                 
</span></span><span class="line"><span class="cl">        <span class="c1">#返回的loss会被自动reduce_mean</span>
</span></span><span class="line"><span class="cl">        <span class="k">return</span> <span class="n">loss</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h4 id="正确率的计算">正确率的计算：</h4>
<p>损失函数中使用欧氏距离来刻画两个向量之间的相似程度，欧氏距离的值域范围是[0,+∞]，并不适合设置一个阀值来衡量相似或者不相似。解决方法是将欧氏距离映射到[0, 1]的区间（归一化），和余弦相似度的值域一样，接近1就越相似，接近0就越不相似。</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">euclidean_distance</span><span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s1">&#39;&#39;&#39;
</span></span></span><span class="line"><span class="cl"><span class="s1">    计算两个tensor的欧氏距离
</span></span></span><span class="line"><span class="cl"><span class="s1">    &#39;&#39;&#39;</span>
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">x1</span> <span class="o">-</span> <span class="n">x2</span><span class="p">),</span> 
</span></span><span class="line"><span class="cl">                    <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span><span class="n">keepdims</span><span class="o">=</span><span class="n">keepdims</span><span class="p">))</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl"><span class="k">def</span> <span class="nf">euclidean_similarity</span><span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">x2</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">    <span class="s1">&#39;&#39;&#39;
</span></span></span><span class="line"><span class="cl"><span class="s1">    将两个tensor之间的欧氏距离转化为相似度，主要是一个归一化的操作
</span></span></span><span class="line"><span class="cl"><span class="s1">    &#39;&#39;&#39;</span>
</span></span><span class="line"><span class="cl">    <span class="n">d</span> <span class="o">=</span> <span class="n">euclidean_distance</span><span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">x2</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    <span class="n">s</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">d</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">    
</span></span><span class="line"><span class="cl">    <span class="k">return</span> <span class="n">s</span>
</span></span></code></pre></td></tr></table>
</div>
</div><p>在计算正确率的时候，将测试样本x1，支持集样本x2（n个类就有n个x2）依次通过孪生网络，得到n个一维相似度数组，最后做argmax运算即可得出x1是属于n个类中的第几个类。</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt"> 1
</span><span class="lnt"> 2
</span><span class="lnt"> 3
</span><span class="lnt"> 4
</span><span class="lnt"> 5
</span><span class="lnt"> 6
</span><span class="lnt"> 7
</span><span class="lnt"> 8
</span><span class="lnt"> 9
</span><span class="lnt">10
</span><span class="lnt">11
</span><span class="lnt">12
</span><span class="lnt">13
</span><span class="lnt">14
</span><span class="lnt">15
</span><span class="lnt">16
</span><span class="lnt">17
</span><span class="lnt">18
</span><span class="lnt">19
</span><span class="lnt">20
</span><span class="lnt">21
</span><span class="lnt">22
</span><span class="lnt">23
</span><span class="lnt">24
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-python" data-lang="python"><span class="line"><span class="cl"><span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s1">&#39;/gpu:0&#39;</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">        <span class="k">for</span> <span class="n">x</span><span class="p">,</span><span class="n">y</span> <span class="ow">in</span> <span class="n">test_iterator</span><span class="p">:</span>
</span></span><span class="line"><span class="cl">            <span class="n">class_probability</span> <span class="o">=</span> <span class="p">[]</span>
</span></span><span class="line"><span class="cl">           
</span></span><span class="line"><span class="cl">            <span class="n">y1</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">call_once</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">            
</span></span><span class="line"><span class="cl">            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">support_n</span><span class="p">):</span>
</span></span><span class="line"><span class="cl">                <span class="n">x2</span> <span class="o">=</span> <span class="n">support_vector</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
</span></span><span class="line"><span class="cl">                <span class="c1">#对y2进行batch_size复制</span>
</span></span><span class="line"><span class="cl">                <span class="n">x2</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">x2</span><span class="p">,</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">                <span class="n">x2</span> <span class="o">=</span> <span class="n">a</span><span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">x2</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">                
</span></span><span class="line"><span class="cl">                <span class="n">y2</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">call_once</span><span class="p">(</span><span class="n">x2</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">                
</span></span><span class="line"><span class="cl">                <span class="c1">#计算相似度</span>
</span></span><span class="line"><span class="cl">                <span class="n">s</span> <span class="o">=</span> <span class="n">euclidean_similarity</span><span class="p">(</span><span class="n">y1</span><span class="p">,</span><span class="n">y2</span><span class="p">)</span> <span class="c1">#shape: (batch_size,)</span>
</span></span><span class="line"><span class="cl">                <span class="n">class_probability</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">s</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">                
</span></span><span class="line"><span class="cl"><span class="c1">#             print(np.array(class_probability).shape)</span>
</span></span><span class="line"><span class="cl">            
</span></span><span class="line"><span class="cl">            <span class="n">y_pred</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">class_probability</span><span class="p">,</span> 
</span></span><span class="line"><span class="cl">                                <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
</span></span><span class="line"><span class="cl">
</span></span><span class="line"><span class="cl">            <span class="n">test_acc</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">y_pred</span> <span class="o">==</span> <span class="n">y</span><span class="p">)</span> <span class="o">/</span> <span class="n">y</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
</span></span></code></pre></td></tr></table>
</div>
</div><h4 id="训练-1">训练：</h4>
<p>机器配置：CPU：i5-8300H；显卡：1050TI；内存：16G；显存：4G</p>
<div class="highlight"><div class="chroma">
<table class="lntable"><tr><td class="lntd">
<pre tabindex="0" class="chroma"><code><span class="lnt">1
</span><span class="lnt">2
</span><span class="lnt">3
</span></code></pre></td>
<td class="lntd">
<pre tabindex="0" class="chroma"><code class="language-bash" data-lang="bash"><span class="line"><span class="cl"><span class="nv">BATCH_SIZE</span> <span class="o">=</span> <span class="m">32</span>
</span></span><span class="line"><span class="cl"><span class="nv">EPOCH</span> <span class="o">=</span> <span class="m">2</span>
</span></span><span class="line"><span class="cl"><span class="nv">LEARNING_RATE</span> <span class="o">=</span> 0.001
</span></span></code></pre></td></tr></table>
</div>
</div><ol>
<li>使用Adam优化器。
训练时，让输入x1和x2分别通过模型，得到两个输出y1和y2，再用y1和y2计算损失和正确率。
训练损失：
<img src="/p/%E5%AD%AA%E7%94%9F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BBsiamese-network/2.png"
	width="442"
	height="278"
	srcset="/p/%E5%AD%AA%E7%94%9F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BBsiamese-network/2_huf92c6438d52cf46256b6f5d441c5b0e5_8667_480x0_resize_box_3.png 480w, /p/%E5%AD%AA%E7%94%9F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BBsiamese-network/2_huf92c6438d52cf46256b6f5d441c5b0e5_8667_1024x0_resize_box_3.png 1024w"
	loading="lazy"
	
		alt="loss"
	
	
		class="gallery-image" 
		data-flex-grow="158"
		data-flex-basis="381px"
	
>
训练正确率：
<img src="/p/%E5%AD%AA%E7%94%9F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BBsiamese-network/3.png"
	width="413"
	height="260"
	srcset="/p/%E5%AD%AA%E7%94%9F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BBsiamese-network/3_hu10b65b6fb176294d368b185a4436db70_7815_480x0_resize_box_3.png 480w, /p/%E5%AD%AA%E7%94%9F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BBsiamese-network/3_hu10b65b6fb176294d368b185a4436db70_7815_1024x0_resize_box_3.png 1024w"
	loading="lazy"
	
		alt="acc"
	
	
		class="gallery-image" 
		data-flex-grow="158"
		data-flex-basis="381px"
	
></li>
</ol>
<h4 id="测试-1">测试：</h4>
<p>每个类的训练样本只有50个，训练时候正确率已经接近100%（过拟合问题后面讨论）。但是测试集的正确率训练的过程中的正确率相差甚远。
（支持集的作用是当一个类示例，看测试样本和示例样本是否相似）
将训练数据（14个类 x 每个类50条样本）求和作为支持集：
平均正确率只有7%，出现了某一批次的测试数据正确率达100%，其余测试数据正确率为0，原因未知。
<img src="/p/%E5%AD%AA%E7%94%9F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BBsiamese-network/4.png"
	width="605"
	height="343"
	srcset="/p/%E5%AD%AA%E7%94%9F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BBsiamese-network/4_hu02f7e38ad13a14933311fcf812b450c5_13449_480x0_resize_box_3.png 480w, /p/%E5%AD%AA%E7%94%9F%E7%BD%91%E7%BB%9C%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BBsiamese-network/4_hu02f7e38ad13a14933311fcf812b450c5_13449_1024x0_resize_box_3.png 1024w"
	loading="lazy"
	
		alt="acc"
	
	
		class="gallery-image" 
		data-flex-grow="176"
		data-flex-basis="423px"
	
></p>
<h3 id="原因分析">原因分析：</h3>
<p>本实验每个类别使用50个样本，满足小样本的特性，但是模型训练完成之后，用测试集测试模型，正确率不到40%。孪生网络应用在人脸识别模型较多（即是输入两张人脸，看这两张人脸是不是属于同一个人），训练一个模型只需要少量样本。人脸图像的特征较为明显（有鼻子、有眼睛等），图像特征也相对容易捕捉。但是当将孪生网络用于少样本的文本分类任务后，一个类别的文本表述方式千变万化，很难通过少量样本找到该类文本的明显特征。
比如，对于财经类的文本：“铜价上涨趋势不变 短期震荡后进一步上扬”出现在了训练集当中，假设这个样本也是财经类的支持集样本，当测试集遇到“午评：期市全线反弹 有色金属郑糖领涨”，出现了“铜”、“有色金属”的特征，则孪生网络可以认为属于同一个类别；但是当测试集样本“保命比业绩更重要 基金上演熊市年底冲刺”，当“基金”、“熊市”等特征没有出现在训练集时，孪生网络则不能正确地划分该类。</p>
<p>综上所述，孪生网络用于多类别文本分类不可行。</p>

</section>


    <footer class="article-footer">
    

    </footer>


    
        <link 
                rel="stylesheet" 
                href="https://cdn.jsdelivr.net/npm/katex@0.13.13/dist/katex.min.css"integrity="sha384-RZU/ijkSsFbcmivfdRBQDtwuwVqK7GMOw6IMvKyeWL2K5UAlyp6WonmB8m7Jd0Hn"crossorigin="anonymous"
            ><script 
                src="https://cdn.jsdelivr.net/npm/katex@0.13.13/dist/katex.min.js"integrity="sha384-pK1WpvzWVBQiP0/GjnvRxV4mOb0oxFuyRxJlk6vVw146n3egcN5C925NCP7a7BY8"crossorigin="anonymous"
                defer
                >
            </script><script 
                src="https://cdn.jsdelivr.net/npm/katex@0.13.13/dist/contrib/auto-render.min.js"integrity="sha384-vZTG03m&#43;2yp6N6BNi5iM4rW4oIwk5DfcNdFfxkk9ZWpDriOkXX8voJBFrAO7MpVl"crossorigin="anonymous"
                defer
                >
            </script><script>
    window.addEventListener("DOMContentLoaded", () => {
        renderMathInElement(document.querySelector(`.article-content`), {
            delimiters: [
                { left: "$$", right: "$$", display: true },
                { left: "$", right: "$", display: false },
                { left: "\\(", right: "\\)", display: false },
                { left: "\\[", right: "\\]", display: true }
            ]
        });})
</script>
    
</article>

    

    

<aside class="related-contents--wrapper">
    <h2 class="section-title">相关文章</h2>
    <div class="related-contents">
        <div class="flex article-list--tile">
            
                
<article class="">
    <a href="/p/%E4%BD%BF%E7%94%A8%E4%BA%8C%E5%88%86%E7%B1%BB%E5%AE%9E%E7%8E%B0%E5%A4%9A%E5%88%86%E7%B1%BB/">
        
        

        <div class="article-details">
            <h2 class="article-title">使用二分类实现多分类</h2>
        </div>
    </a>
</article>
            
                
<article class="">
    <a href="/p/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%8F%8D%E5%90%91%E4%BC%A0%E6%92%AD%E6%8E%A8%E5%AF%BC/">
        
        

        <div class="article-details">
            <h2 class="article-title">深度学习反向传播推导</h2>
        </div>
    </a>
</article>
            
                
<article class="">
    <a href="/p/pytorch%E5%8A%A0%E8%BD%BD%E5%A4%A7%E6%95%B0%E6%8D%AE%E9%9B%86/">
        
        

        <div class="article-details">
            <h2 class="article-title">pytorch加载大数据集</h2>
        </div>
    </a>
</article>
            
                
<article class="">
    <a href="/p/pytorch%E5%A4%A7%E6%95%B0%E6%8D%AE%E9%9B%86num_workers%E5%A4%A7%E4%BA%8E1%E5%86%85%E5%AD%98%E7%BC%93%E6%85%A2%E5%A2%9E%E6%B6%A8/">
        
        

        <div class="article-details">
            <h2 class="article-title">Pytorch大数据集num_workers大于1内存缓慢增涨</h2>
        </div>
    </a>
</article>
            
                
<article class="">
    <a href="/p/%E5%B0%8F%E6%A0%B7%E6%9C%AC%E5%AD%A6%E4%B9%A0/">
        
        

        <div class="article-details">
            <h2 class="article-title">小样本学习</h2>
        </div>
    </a>
</article>
            
        </div>
    </div>
</aside>

     
    
        
    

    <footer class="site-footer">
    <section class="copyright">
        &copy; 
        
            2018 - 
        
        2024 Charent的博客
    </section>
    
    <section class="powerby">
         <br />
        
    </section>
</footer>


    
<div class="pswp" tabindex="-1" role="dialog" aria-hidden="true">

    
    <div class="pswp__bg"></div>

    
    <div class="pswp__scroll-wrap">

        
        <div class="pswp__container">
            <div class="pswp__item"></div>
            <div class="pswp__item"></div>
            <div class="pswp__item"></div>
        </div>

        
        <div class="pswp__ui pswp__ui--hidden">

            <div class="pswp__top-bar">

                

                <div class="pswp__counter"></div>

                <button class="pswp__button pswp__button--close" title="Close (Esc)"></button>

                <button class="pswp__button pswp__button--share" title="Share"></button>

                <button class="pswp__button pswp__button--fs" title="Toggle fullscreen"></button>

                <button class="pswp__button pswp__button--zoom" title="Zoom in/out"></button>

                
                
                <div class="pswp__preloader">
                    <div class="pswp__preloader__icn">
                        <div class="pswp__preloader__cut">
                            <div class="pswp__preloader__donut"></div>
                        </div>
                    </div>
                </div>
            </div>

            <div class="pswp__share-modal pswp__share-modal--hidden pswp__single-tap">
                <div class="pswp__share-tooltip"></div>
            </div>

            <button class="pswp__button pswp__button--arrow--left" title="Previous (arrow left)">
            </button>

            <button class="pswp__button pswp__button--arrow--right" title="Next (arrow right)">
            </button>

            <div class="pswp__caption">
                <div class="pswp__caption__center"></div>
            </div>

        </div>

    </div>

</div><script 
                src="https://cdn.jsdelivr.net/npm/photoswipe@4.1.3/dist/photoswipe.min.js"integrity="sha256-ePwmChbbvXbsO02lbM3HoHbSHTHFAeChekF1xKJdleo="crossorigin="anonymous"
                defer
                >
            </script><script 
                src="https://cdn.jsdelivr.net/npm/photoswipe@4.1.3/dist/photoswipe-ui-default.min.js"integrity="sha256-UKkzOn/w1mBxRmLLGrSeyB4e1xbrp4xylgAWb3M42pU="crossorigin="anonymous"
                defer
                >
            </script><link 
                rel="stylesheet" 
                href="https://cdn.jsdelivr.net/npm/photoswipe@4.1.3/dist/default-skin/default-skin.css"integrity="sha256-c0uckgykQ9v5k&#43;IqViZOZKc47Jn7KQil4/MP3ySA3F8="crossorigin="anonymous"
            ><link 
                rel="stylesheet" 
                href="https://cdn.jsdelivr.net/npm/photoswipe@4.1.3/dist/photoswipe.css"integrity="sha256-SBLU4vv6CA6lHsZ1XyTdhyjJxCjPif/TRkjnsyGAGnE="crossorigin="anonymous"
            >

            </main>
    <aside class="sidebar right-sidebar sticky">
        
            
                
    <section class="widget archives">
        <div class="widget-icon">
            <svg xmlns="http://www.w3.org/2000/svg" class="icon icon-tabler icon-tabler-hash" width="24" height="24" viewBox="0 0 24 24" stroke-width="2" stroke="currentColor" fill="none" stroke-linecap="round" stroke-linejoin="round">
  <path stroke="none" d="M0 0h24v24H0z"/>
  <line x1="5" y1="9" x2="19" y2="9" />
  <line x1="5" y1="15" x2="19" y2="15" />
  <line x1="11" y1="4" x2="7" y2="20" />
  <line x1="17" y1="4" x2="13" y2="20" />
</svg>



        </div>
        <h2 class="widget-title section-title">目录</h2>
        
        <div class="widget--toc">
            <nav id="TableOfContents">
  <ol>
    <li>
      <ol>
        <li>
          <ol>
            <li><a href="#孪生网络之小样本学习">孪生网络之小样本学习：</a></li>
            <li><a href="#孪生网络">孪生网络</a></li>
            <li><a href="#损失函数">损失函数：</a></li>
            <li><a href="#处理多分类少样本的问题">处理多分类、少样本的问题：</a></li>
            <li><a href="#训练">训练：</a></li>
            <li><a href="#测试">测试：</a></li>
            <li><a href="#实验">实验：</a></li>
            <li><a href="#字向量">字向量：</a></li>
            <li><a href="#模型">模型：</a></li>
            <li><a href="#损失函数-1">损失函数：</a></li>
            <li><a href="#正确率的计算">正确率的计算：</a></li>
            <li><a href="#训练-1">训练：</a></li>
            <li><a href="#测试-1">测试：</a></li>
          </ol>
        </li>
        <li><a href="#原因分析">原因分析：</a></li>
      </ol>
    </li>
  </ol>
</nav>
        </div>
    </section>

            
        
    </aside>


        </div>
        <script 
                src="https://cdn.jsdelivr.net/npm/node-vibrant@3.1.5/dist/vibrant.min.js"integrity="sha256-5NovOZc4iwiAWTYIFiIM7DxKUXKWvpVEuMEPLzcm5/g="crossorigin="anonymous"
                
                >
            </script><script type="text/javascript" src="/ts/main.js" defer></script>
<script>
    (function () {
        const customFont = document.createElement('link');
        customFont.href = "https://fonts.googleapis.com/css2?family=Lato:wght@300;400;700&display=swap";

        customFont.type = "text/css";
        customFont.rel = "stylesheet";

        document.head.appendChild(customFont);
    }());
</script>

    </body>
</html>
